import os

from MADCluster_UTILS import seed_everything, set_device, set_feature_size, he_init_normal, get_loader_segment, EarlyStopping
from MADCluster_MODEL import Time_Series_Deep_SVDD
from MADCluster_SOLVER import main_trainer

import argparse

import torch
import torch.nn as nn

import warnings
# Suppress every warning category for the lifetime of the process.
warnings.filterwarnings("ignore")

# Runs at import time with its default arguments; presumably fixes the RNG
# seeds for reproducibility — confirm against MADCluster_UTILS.seed_everything.
seed_everything()

def get_args_parser():
    """Build the command-line option parser used by this script.

    The parser is created with ``add_help=False`` so that it can be handed to
    another ``ArgumentParser`` through ``parents=[...]`` without registering a
    second ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser with the model, optimizer and training
        options registered, in that order.
    """
    parser = argparse.ArgumentParser('PyTorch Training', add_help=False)

    def register(options):
        # Each entry is (flag, default, type) for a single-valued option.
        for flag, default, kind in options:
            parser.add_argument(flag, default=default, type=kind)

    # Model parameters
    register((
        ('--batch_size', 256, int),
        ('--window_size', 100, int),
        ('--nu', 0.01, float),
        ('--objective', 'soft-boundary', str),
        ('--cluster_num', 2, int),
    ))
    # Per-layer settings take one or more integers on the command line.
    for flag, default, description in (
        ('--hidden_structs', [32, 64, 128], 'List of hidden dimensions for each layer'),
        ('--dilations', [1, 2, 3], 'List of dilation rates for each layer'),
    ):
        parser.add_argument(flag, nargs='+', type=int, default=default, help=description)

    # Optimizer parameters
    register((
        ('--smoothing_factor', 0.1, float),
        ('--lr', 1e-3, float),
        ('--lr_lambda', 0.99, float),
        ('--min_delta', 1e-5, float),
        ('--weight_decay', 0.0001, float),
        ('--init_method', 'he', str),
        ('--init_threshold', 0.5, float),
        ('--threshold_ratio', 0.5, float),
    ))

    # Training parameters
    register((
        ('--dataset', 'MSL', str),
        ('--num_epochs', 15, int),
        ('--patience', 3, int),
        ('--anormly_ratio', 1, float),
        ('--device', '0', str),
        ('--step', 1, int),
        ('--mode', 'Train', str),
    ))

    return parser

def get_model_name(config):
    """Compose a file-name stem that encodes the key hyper-parameters of a run.

    Each parameter contributes a short prefix immediately followed by its
    value; the pieces are joined with underscores.

    Args:
        config: mapping that provides 'batch_size', 'window_size',
            'cluster_num', 'nu', 'smoothing_factor', 'lr', 'weight_decay',
            'init_threshold', 'num_epochs', 'dataset', 'hidden_structs' and
            'dilations'.

    Returns:
        str: the underscore-joined name.
    """
    # (prefix, config key, format spec); an empty spec means plain str() rendering.
    layout = (
        ('bs', 'batch_size', ''),
        ('ws', 'window_size', ''),
        ('cn', 'cluster_num', ''),
        ('nu', 'nu', '.3f'),
        ('sf', 'smoothing_factor', '.1f'),
        ('lr', 'lr', '.1e'),
        ('wd', 'weight_decay', '.1e'),
        ('it', 'init_threshold', '.1f'),
        ('ep', 'num_epochs', ''),
        ('ds', 'dataset', ''),
    )
    pieces = [tag + format(config[key], spec) for tag, key, spec in layout]

    # List-valued settings are flattened with underscores between elements.
    for tag, key in (('hs', 'hidden_structs'), ('dl', 'dilations')):
        pieces.append(tag + '_'.join(str(item) for item in config[key]))

    return '_'.join(pieces)

def main(args):
    """Validate the run options, build data/model/optimizers and call ``main_trainer``.

    Args:
        args: ``argparse.Namespace`` with the options declared in
            ``get_args_parser``. ``vars(args)`` is used as the ``config`` dict,
            so the keys added here ('data_path', 'feature_size',
            'model_save_name') and the replaced 'device' value are also
            visible on ``args``.

    Raises:
        AssertionError: if 'mode', 'init_method', 'objective', 'nu',
            'init_threshold' or 'smoothing_factor' is outside its accepted
            set or range.
    """
    config = vars(args)

    assert config['mode'] in ('Train', 'Test'), "Mode must be either 'Train' or 'Test'."
    assert config['init_method'] in ('he', 'xavier', 'generic'), "initialize method must be either 'he', 'xavier', 'generic'."
    assert config['objective'] in ('one-class', 'soft-boundary'), "Objective must be either 'one-class' or 'soft-boundary'."
    assert 0 < config['nu'] <= 1, "For hyperparameter nu, it must hold: 0 < nu <= 1."
    assert 0 < config['init_threshold'] < 1, "For init_threshold, it must hold: 0 < init_threshold < 1."
    # The message now states the bound that is actually enforced (strict on both sides).
    assert 0 < config['smoothing_factor'] < 0.5, "For smoothing_factor, it must hold: 0 < smoothing_factor < 0.5."

    # Computed before the derived keys are added; it only reads CLI options.
    model_save_name = get_model_name(config)

    config['data_path'] = '../dataset/' + config['dataset']
    config['feature_size'] = set_feature_size(config['dataset'])

    print('------------ Options -------------')
    for k, v in sorted(config.items()):
        print('%s: %s' % (str(k), str(v)))
    print('-------------- End ----------------')

    config['model_save_name'] = os.path.join('..', 'results', model_save_name)
    print('model_save_name:', model_save_name)
    config['device'] = set_device(config['device'])

    # Directory where the best model weights are saved to / loaded from.
    results_dir = os.path.join('..', 'results')
    os.makedirs(results_dir, exist_ok=True)

    # -------------------------------------------------------------------------------------------

    def build_loader(split):
        # All four loaders share every setting except the split name.
        return get_loader_segment(config['data_path'],
                                  batch_size=config['batch_size'],
                                  window_size=config['window_size'],
                                  mode=split,
                                  step=config['step'],
                                  dataset=config['dataset'])

    train_loader = build_loader('train')
    valid_loader = build_loader('val')
    test_loader = build_loader('test')
    thre_loader = build_loader('thre')

    # -------------------------------------------------------------------------------------------

    # Training the Deep SVDD model
    model = Time_Series_Deep_SVDD(config).to(config['device'])
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model).to(config['device'])

    # nn.DataParallel does not forward custom methods or the unprefixed
    # state dict, so resolve the underlying network once and use it wherever
    # the model's own API is needed.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    # NOTE(review): 'init_method' is validated above, but He initialization is
    # applied regardless of its value — confirm whether that is intended.
    model.apply(he_init_normal)
    # initialize radius vector as much as number of clusters
    R = torch.zeros(config['cluster_num'], device=config['device'])

    if config['mode'] == 'Test':
        # map_location lets a checkpoint saved on another device be restored on
        # the device selected for this run.
        load_best_model_wts = torch.load('{}.pt'.format(config['model_save_name']),
                                         map_location=config['device'])
        base_model.load_state_dict(load_best_model_wts['net_dict'])
        R = torch.tensor(load_best_model_wts['R'], device=config['device'])

    # Create optimizer for model parameters excluding 'thre'
    optimizer = torch.optim.Adam(base_model.get_model_params(), lr=config['lr'], betas=(0.9, 0.999), weight_decay=config['weight_decay'], amsgrad=True)

    # Create optimizer for 'thre' parameter
    optimizer2 = torch.optim.Adam(base_model.get_thre_param(), lr=config['lr'], amsgrad=True)
    # Only the main optimizer is scheduled: its lr is scaled by lr_lambda ** epoch.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: config['lr_lambda'] ** epoch)
    early_stopping_loss = EarlyStopping(patience=config['patience'], mode='min', min_delta=config['min_delta'])
    main_trainer(model, scheduler, optimizer, optimizer2, early_stopping_loss, config, train_loader, valid_loader, test_loader, thre_loader, R)

if __name__ == '__main__':
    # Wrap the shared option parser in a top-level parser (which supplies
    # -h/--help), parse the command line and hand the namespace to main().
    cli = argparse.ArgumentParser('Training script', parents=[get_args_parser()])
    main(cli.parse_args())